{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# iTree visualization for a toy dataset\n",
    "\n",
    "This notebook shows how iTree, a decision tree used in Isolation Forest, finds outliers.\n",
    "We show it for a 2-D toy dataset and three different anomaly detection models: Isolation Forest, Pine Forest, and AAD. \n",
    "\n",
    "To run this notebook you need `graphviz` installed and its `dot` command available on your `$PATH`.\n",
    "You can install it with `apt install graphviz`, `brew install graphviz`, or the equivalent command for your system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:42.632776Z",
     "start_time": "2024-03-14T14:28:42.478397Z"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import subprocess\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Default size and resolution for every figure created below;\n",
    "# `figsize` is also kept as a notebook-level name.\n",
    "figsize = (8, 6)\n",
    "plt.rcParams.update({'figure.figsize': figsize, 'figure.dpi': 90})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:42.664522Z",
     "start_time": "2024-03-14T14:28:42.633794Z"
    }
   },
   "outputs": [],
   "source": [
    "class TreeViz:\n",
    "    \"\"\"\n",
    "    Tree visualization with matplotlib or graphviz.\n",
    "\n",
    "    `tree` must expose `children_left`, `children_right`, `feature` and\n",
    "    `threshold` (plus `n_features` when no known data is given); these are\n",
    "    the only attributes read here. `known_data` / `known_labels` are the\n",
    "    labelled points drawn as stars, `data` / `labels` the full dataset\n",
    "    used by `display_2d`.\n",
    "    \"\"\"\n",
    "    def __init__(self, tree, known_data=None, known_labels=None, data=None, labels=None):\n",
    "        self.tree = tree\n",
    "        self.known_data = known_data\n",
    "        self.known_labels = known_labels\n",
    "        self.data = data\n",
    "        self.labels = labels\n",
    "\n",
    "        if known_data is None:\n",
    "            # No labelled points: fall back to empty arrays (any `known_labels`\n",
    "            # passed without data is dropped) so the boolean masks used by the\n",
    "            # drawing code still work.\n",
    "            self.known_data = np.empty((0, tree.n_features))\n",
    "            self.known_labels = np.empty((0,))\n",
    "\n",
    "    @staticmethod\n",
    "    def _node_label(reds, blues):\n",
    "        \"\"\"\n",
    "        Build the graphviz label of a leaf: a block of red and blue stars.\n",
    "\n",
    "        Returns the empty quoted label when both counts are zero, otherwise\n",
    "        an HTML-like label with the stars wrapped into rows.\n",
    "        \"\"\"\n",
    "        # Counts arrive as numpy integers from `np.sum`; plain ints keep the\n",
    "        # list repetition below unambiguous.\n",
    "        reds, blues = int(reds), int(blues)\n",
    "        total = reds + blues\n",
    "        if total == 0:\n",
    "            return '\"\"'\n",
    "\n",
    "        # Row width grows with the square root of the star count.\n",
    "        columns = int(np.ceil(np.sqrt(total / 12) * 4))\n",
    "\n",
    "        stars = ['<font color=\"red\">&#x2605;</font>'] * reds + ['<font color=\"blue\">&#x2605;</font>'] * blues\n",
    "        rows = [''.join(stars[start:start + columns]) for start in range(0, total, columns)]\n",
    "\n",
    "        return '<' + '<br/>'.join(rows) + '>'\n",
    "\n",
    "    def _dot_walker(self, current, known_data, known_labels):\n",
    "        \"\"\"\n",
    "        Walk through the tree recursively and collect dot statements for graphviz.\n",
    "\n",
    "        `known_data` / `known_labels` are the labelled points that reach node\n",
    "        `current`. Returns a list of dot-file lines for the node and its subtree.\n",
    "        \"\"\"\n",
    "        tree = self.tree\n",
    "        text = []\n",
    "        if tree.children_left[current] != -1:\n",
    "            # Internal node: draw it without a label, then render both children.\n",
    "            text.append(f' n{current} [label=\"\"];')\n",
    "\n",
    "            # Points with feature value <= threshold follow the left child,\n",
    "            # the remaining ones follow the right child.\n",
    "            index = known_data[:, tree.feature[current]] <= tree.threshold[current]\n",
    "\n",
    "            for child, mask in (\n",
    "                (tree.children_left[current], index),\n",
    "                (tree.children_right[current], ~index),\n",
    "            ):\n",
    "                text.append(f' n{current} -- n{child};')\n",
    "                text.extend(self._dot_walker(child, known_data[mask, :], known_labels[mask]))\n",
    "        else:\n",
    "            # Leaf: show the labelled points that ended up here as stars.\n",
    "            reds = np.sum(known_labels == Label.ANOMALY)\n",
    "            blues = np.sum(known_labels == Label.REGULAR)\n",
    "            label = self._node_label(reds, blues)\n",
    "            text.append(f' n{current} [label={label}];')\n",
    "\n",
    "        return text\n",
    "\n",
    "    def generate_dot(self):\n",
    "        \"\"\"Return the whole tree as the text of an undirected graphviz graph.\"\"\"\n",
    "        body = self._dot_walker(0, self.known_data, self.known_labels)\n",
    "        return '\\n'.join(['graph \"\"', '{', *body, '}'])\n",
    "\n",
    "    def draw_graph(self, timeout=10):\n",
    "        \"\"\"\n",
    "        Render the tree with the graphviz `dot` command and return PNG bytes.\n",
    "\n",
    "        Raises `subprocess.CalledProcessError` if `dot` exits with an error and\n",
    "        `subprocess.TimeoutExpired` if it runs longer than `timeout` seconds.\n",
    "        \"\"\"\n",
    "        # `subprocess.run` kills and reaps the child on timeout and, with\n",
    "        # `check=True`, reports a failed `dot` instead of returning broken bytes.\n",
    "        result = subprocess.run(\n",
    "            ['dot', '-Tpng'],\n",
    "            input=self.generate_dot().encode('utf-8'),\n",
    "            stdout=subprocess.PIPE,\n",
    "            timeout=timeout,\n",
    "            check=True,\n",
    "        )\n",
    "        return result.stdout\n",
    "\n",
    "    def display_graph(self):\n",
    "        \"\"\"Display tree as a graph; returns an `IPython.display.Image`.\"\"\"\n",
    "        import IPython.display\n",
    "\n",
    "        return IPython.display.Image(self.draw_graph())\n",
    "\n",
    "    def display_2d(self, full=False):\n",
    "        \"\"\"\n",
    "        Draw the tree's splits as gray lines over a scatter plot of 2-D data.\n",
    "\n",
    "        With `full=True` every point of `data` is plotted, colored by `labels`,\n",
    "        and only those series get legend entries. The labelled points\n",
    "        (`known_data`) are always drawn as stars.\n",
    "\n",
    "        Returns `(fig, ax)`. Raises `ValueError` if `data` is missing or does\n",
    "        not have exactly two columns.\n",
    "        \"\"\"\n",
    "        data = self.data\n",
    "        labels = self.labels\n",
    "        known_data = self.known_data\n",
    "        known_labels = self.known_labels\n",
    "        if data is None:\n",
    "            raise ValueError('no data to plot')\n",
    "\n",
    "        if data.shape[1] != 2:\n",
    "            raise ValueError('only 2d plots are supported')\n",
    "\n",
    "        fig, ax = plt.subplots()\n",
    "\n",
    "        # Bounding box of the data: row 0 is the minimum, row 1 the maximum.\n",
    "        frame = np.empty((2, 2))\n",
    "        frame[0, :] = data.min(axis=0)\n",
    "        frame[1, :] = data.max(axis=0)\n",
    "\n",
    "        self._draw_subsets(ax, frame, 0)\n",
    "\n",
    "        classes = ((Label.R, 'blue', 'Normal'), (Label.A, 'red', 'Anomalous'))\n",
    "\n",
    "        if full:\n",
    "            for value, color, name in classes:\n",
    "                ax.scatter(*data[labels == value, :].T, color=color, s=10, label=name)\n",
    "\n",
    "        for value, color, name in classes:\n",
    "            star_style = {'color': color, 'marker': '*', 's': 80}\n",
    "            if not full:\n",
    "                # Stars carry the legend entries only when they are the sole series.\n",
    "                star_style['label'] = name\n",
    "            ax.scatter(*known_data[known_labels == value, :].T, **star_style)\n",
    "\n",
    "        ax.set(xlabel='x1', ylabel='x2', xlim=frame[:, 0], ylim=frame[:, 1])\n",
    "        ax.legend()\n",
    "\n",
    "        return fig, ax\n",
    "\n",
    "    def _draw_subsets(self, ax, frame, current):\n",
    "        \"\"\"\n",
    "        Recursively draw the split line of node `current` and of its descendants.\n",
    "\n",
    "        `frame` is a 2x2 array: row 0 holds the lower and row 1 the upper\n",
    "        bounds of the rectangle covered by the node.\n",
    "        \"\"\"\n",
    "        tree = self.tree\n",
    "        feature = tree.feature[current]\n",
    "        if feature < 0:\n",
    "            # A negative feature index means there is no split to draw here.\n",
    "            return\n",
    "\n",
    "        threshold = tree.threshold[current]\n",
    "        if feature == 0:\n",
    "            # Split on the first coordinate: vertical segment across the rectangle.\n",
    "            ax.plot([threshold, threshold], [frame[0, 1], frame[1, 1]], color='gray', zorder=1)\n",
    "        else:\n",
    "            # Split on the second coordinate: horizontal segment.\n",
    "            ax.plot([frame[0, 0], frame[1, 0]], [threshold, threshold], color='gray', zorder=1)\n",
    "\n",
    "        # The left child keeps the lower part of the rectangle ...\n",
    "        left_frame = frame.copy()\n",
    "        left_frame[1, feature] = threshold\n",
    "        self._draw_subsets(ax, left_frame, tree.children_left[current])\n",
    "\n",
    "        # ... and the right child the upper part.\n",
    "        right_frame = frame.copy()\n",
    "        right_frame[0, feature] = threshold\n",
    "        self._draw_subsets(ax, right_frame, tree.children_right[current])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:42.766487Z",
     "start_time": "2024-03-14T14:28:42.665135Z"
    }
   },
   "outputs": [],
   "source": [
    "from coniferest.datasets import Label, non_anomalous_outliers\n",
    "\n",
    "data, labels = non_anomalous_outliers(inliers=1000, outliers=50)\n",
    "\n",
    "# Overview of the toy dataset: one scatter series per label value.\n",
    "overview_fig, overview_ax = plt.subplots()\n",
    "for label_value, color, name in [(Label.R, 'blue', 'Normal'), (Label.A, 'red', 'Anomalous')]:\n",
    "    overview_ax.scatter(*data[labels == label_value, :].T, color=color, s=10, label=name)\n",
    "overview_ax.set(title='Data overview', xlabel='x1', ylabel='x2')\n",
    "# Trailing semicolon keeps the legend repr out of the cell output.\n",
    "overview_ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get some experiment data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:43.327089Z",
     "start_time": "2024-03-14T14:28:42.767906Z"
    }
   },
   "outputs": [],
   "source": [
    "from coniferest.session.oracle import create_oracle_session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:44.566929Z",
     "start_time": "2024-03-14T14:28:43.327770Z"
    }
   },
   "outputs": [],
   "source": [
    "from coniferest.isoforest import IsolationForest\n",
    "\n",
    "# Plain Isolation Forest with a fixed `random_seed`.\n",
    "isoforest = IsolationForest(\n",
    "    n_trees=30,\n",
    "    n_subsamples=64,\n",
    "    max_depth=5,\n",
    "    random_seed=0,\n",
    ")\n",
    "\n",
    "# The session is kept because its `known_labels` are plotted further below.\n",
    "session_isoforest = create_oracle_session(data=data, labels=labels, model=isoforest).run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:46.065491Z",
     "start_time": "2024-03-14T14:28:44.567538Z"
    }
   },
   "outputs": [],
   "source": [
    "from coniferest.pineforest import PineForest\n",
    "\n",
    "# Same n_trees, n_subsamples, max_depth and random_seed as the Isolation Forest\n",
    "# above, plus `n_spare_trees`.\n",
    "pineforest = PineForest(n_trees=30, n_spare_trees=100, n_subsamples=64, max_depth=5, random_seed=0)\n",
    "\n",
    "# The session is kept because its `known_labels` are plotted further below.\n",
    "session_pineforest = create_oracle_session(data=data, labels=labels, model=pineforest).run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Isoforest trees by themselves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:51.110035Z",
     "start_time": "2024-03-14T14:28:46.066387Z"
    }
   },
   "outputs": [],
   "source": [
    "# No labelled points are passed here; every data point is drawn with the regular label.\n",
    "all_regular = np.full_like(labels, Label.R)\n",
    "for tree in isoforest.trees:\n",
    "    viz = TreeViz(tree, data=data, labels=all_regular)\n",
    "    graph_image = viz.display_graph()\n",
    "    display(graph_image)\n",
    "    fig, ax = viz.display_2d(full=True)\n",
    "    ax.legend().remove()\n",
    "    display(fig)\n",
    "    # Close explicitly so the figure is not shown a second time and does not pile up in memory.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:28:51.143752Z",
     "start_time": "2024-03-14T14:28:51.134023Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_trees_with_data(session, model):\n",
    "    \"\"\"\n",
    "    Show every tree of `model` as a graph and as a 2-D partition, marking the\n",
    "    points labelled during `session`.\n",
    "\n",
    "    `session.known_labels` is read as a mapping from row index in\n",
    "    `session._data` to label; `session._metadata` is passed on as the labels\n",
    "    of the full dataset.\n",
    "    \"\"\"\n",
    "    # Explicit integer dtype: an empty mapping would otherwise give a float\n",
    "    # array, which numpy rejects as an index.\n",
    "    known_index = np.array(list(session.known_labels.keys()), dtype=np.intp)\n",
    "    known_labels = np.array(list(session.known_labels.values()))\n",
    "    known_data = session._data[known_index]\n",
    "\n",
    "    for tree in model.trees:\n",
    "        viz = TreeViz(\n",
    "            tree,\n",
    "            known_data=known_data,\n",
    "            known_labels=known_labels,\n",
    "            data=session._data,\n",
    "            labels=session._metadata,\n",
    "        )\n",
    "        display(viz.display_graph())\n",
    "        fig, _ = viz.display_2d()\n",
    "        display(fig)\n",
    "        # Close explicitly so figures do not accumulate across the loop.\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pineforest trees on data seen by Pineforest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:29:00.544710Z",
     "start_time": "2024-03-14T14:28:51.148671Z"
    }
   },
   "outputs": [],
   "source": [
    "# PineForest trees together with the labels collected in the PineForest session.\n",
    "plot_trees_with_data(session=session_pineforest, model=pineforest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pineforest trees on data seen by Isoforest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:29:10.266755Z",
     "start_time": "2024-03-14T14:29:00.557372Z"
    }
   },
   "outputs": [],
   "source": [
    "# Session and model are deliberately mixed: PineForest trees are shown with\n",
    "# the labels collected in the Isolation Forest session.\n",
    "plot_trees_with_data(session=session_isoforest, model=pineforest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Isoforest trees on data seen by Isoforest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-14T14:29:20.452848Z",
     "start_time": "2024-03-14T14:29:10.268378Z"
    }
   },
   "outputs": [],
   "source": [
    "# Isolation Forest trees together with the labels collected in the Isolation Forest session.\n",
    "plot_trees_with_data(session=session_isoforest, model=isoforest)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
